package connection

import (
	"errors"
	"net"
	"time"

	"reflect"

	"math/rand"

	"github.com/VolantMQ/volantmq/packet"
	"go.uber.org/zap"
	"sync/atomic"
	"sync"
	"fmt"
)

// gPush enqueues pkt on txGMessages (the queue txRoutine reads as the
// QoS 0 path), with txGBMessages as the fallback buffer used by push.
func (s *Type) gPush(pkt packet.Provider) {
	queue, fallback := s.txGMessages, &s.txGBMessages
	s.push(pkt, queue, fallback)
}

// qPush enqueues pkt on txQMessages (the queue txRoutine reads as the
// QoS > 0 path), with txQBMessages as the fallback buffer used by push.
func (s *Type) qPush(pkt packet.Provider) {
	queue, fallback := s.txQMessages, &s.txQBMessages
	s.push(pkt, queue, fallback)
}

// push hands pkt to the transmit routine through queue. When the transmit
// routine is not running, pkt is stored in bQueue instead, keyed by the
// address of this call's pkt variable (distinct for every call).
//
// While the transmit routine is running and queue is full, push logs a
// warning, sleeps for one second and retries, holding tLock for reading the
// whole time.
func (s *Type) push(pkt packet.Provider, queue chan packet.Provider, bQueue *sync.Map) {
	if atomic.LoadUint32(&s.txRunning) != 1 {
		bQueue.Store(&pkt, pkt)
		return
	}

	s.tLock.RLock()
	defer s.tLock.RUnlock()

	// txRoutine clears txRunning and then closes the queues under the write
	// lock. It may have done both between the check above and RLock being
	// granted, so look at the flag again now that the read lock is held:
	// sending on an already closed channel would panic.
	if atomic.LoadUint32(&s.txRunning) != 1 {
		bQueue.Store(&pkt, pkt)
		return
	}

	for {
		// Non-blocking send, so that a full queue whose consumer has already
		// exited cannot block this caller forever.
		select {
		case queue <- pkt:
			return
		default:
		}

		if atomic.LoadUint32(&s.txRunning) == 0 {
			// The client went offline: keep the message in the offline buffer.
			bQueue.Store(&pkt, pkt)
			return
		}

		s.log.Warn("Message transfer buffer queue is full, either the size is too small or the client is too slow to ack the message.",
			zap.String("ClientID", s.ID),
			zap.Int("QueueSize", cap(queue)),
		)
		// Back off, then try the queue (or the offline buffer) once more.
		time.Sleep(time.Second * 1)
	}
}


// flushBuffers writes every buffer in buf to s.Conn and returns the write
// error, if any.
//
// buf is received by value, so the caller's slice header is not reset by this
// function; callers that reuse the variable must clear it themselves.
func (s *Type) flushBuffers(buf net.Buffers) error {
	s.log.Debug("flush buffered messages to client", zap.Int("count", len(buf)), zap.String("ClientID", s.ID))
	// todo metrics
	_, err := buf.WriteTo(s.Conn)
	return err
}

// packetFitsSize reports whether value may be transmitted, i.e. whether its
// size does not exceed s.MaxTxPacketSize.
//
// It returns false when the size cannot be computed or when it is larger than
// the limit. A value that does not implement sizeAble is reported through
// s.log.Fatal.
func (s *Type) packetFitsSize(value interface{}) bool {
	size := 0

	if sized, isSizeAble := value.(sizeAble); isSizeAble {
		var sizeErr error
		if size, sizeErr = sized.Size(); sizeErr != nil {
			s.log.Error("Couldn't calculate message size", zap.String("ClientID", s.ID), zap.Error(sizeErr))
			return false
		}
	} else {
		s.log.Fatal("Object does not belong to allowed types",
			zap.String("ClientID", s.ID),
			zap.String("Type", reflect.TypeOf(value).String()))
	}

	if size <= int(s.MaxTxPacketSize) {
		return true
	}

	// Packets larger than the negotiated maximum are skipped.
	s.log.Warn("Ignore packet with size bigger than negotiated with client",
		zap.String("ClientID", s.ID),
		zap.Uint32("negotiated", s.MaxTxPacketSize),
		zap.Int("actual", sz(size)))
	return false
}

// sz is an identity helper that keeps the logged field value unchanged.
func sz(n int) int { return n }

// txRoutine is the transmit loop of the connection. It takes packets from
// txQMessages (QoS > 0) and txGMessages (QoS 0), encodes them and writes them
// to s.Conn until s.quit is signalled or a write fails.
//
// On exit it marks the routine as stopped, closes both queues under the
// tLock write lock and, when it stopped with an error, calls
// onConnectionClose.
func (s *Type) txRoutine() {
	var err error

	defer func() {
		s.txWg.Done()
		// Clear the flag before closing the queues so that push, which checks
		// the flag while holding the read lock, never sends on a closed channel.
		atomic.StoreUint32(&s.txRunning, 0)
		s.tLock.Lock()
		close(s.txQMessages)
		close(s.txGMessages)
		s.tLock.Unlock()
		if err != nil {
			s.onConnectionClose(true, nil)
		}
	}()

	// trySend encodes pkt and writes it to the connection. Only write errors
	// are returned; expired, oversized or un-encodable packets are skipped and
	// reported as success.
	trySend := func(pkt packet.Provider) error {
		if pub, isPublish := pkt.(*packet.Publish); isPublish {
			s.log.Debug("Preparing sending message to client.",
				zap.String("ClientID", s.ID),
				zap.Int64("MsgCreatedAt", pub.GetCreateTimestamp()),
				zap.String("Topic", pub.Topic()),
			)
			if pub.Expired(true) {
				s.log.Debug("Message to be sent to client is expired.",
					zap.String("ClientID", s.ID),
					zap.Int64("MsgCreatedAt", pub.GetCreateTimestamp()),
					zap.String("Topic", pub.Topic()),
				)
				return nil
			}
			s.setTopicAlias(pub)
		}

		if pkt == nil || !s.packetFitsSize(pkt) {
			return nil
		}

		buf, encErr := packet.Encode(pkt)
		if encErr != nil {
			s.log.Error("Message encode", zap.String("ClientID", s.ID), zap.Error(encErr))
			return nil
		}

		s.log.Debug("Begin sending message to client.",
			zap.String("ClientID", s.ID),
			zap.Int64("MsgCreatedAt", pkt.GetCreateTimestamp()),
		)

		// Keep writing until the whole encoded packet has been handed over.
		total := len(buf)
		for sent := 0; ; {
			n, wrErr := s.Conn.Write(buf[sent:])
			if wrErr != nil {
				return wrErr
			}
			sent += n
			if sent >= total {
				s.log.Debug("Finished sending message to client.",
					zap.String("ClientID", s.ID),
					zap.Int64("MsgCreatedAt", pkt.GetCreateTimestamp()),
				)
				return nil
			}
		}
	}

	for {
		var m packet.Provider
		select {
		case <-s.quit:
			err = errors.New("exit")
			return
		// Messages with QoS > 0.
		case pkt := <-s.txQMessages:
			switch t := pkt.(type) {
			case *packet.Publish:
				// Try to acquire a packet id. A dedicated error variable is
				// used so the outer err, read by the deferred cleanup, is not
				// shadowed.
				id, acquireErr := s.flowAcquire()
				if acquireErr != nil {
					if acquireErr != errExit {
						s.log.Warn("Packet flowAccquire err.", zap.Error(acquireErr), zap.String("ClientID", s.ID))
					}
					// Put the packet back on the queue to be handled later.
					// NOTE(review): push retries while the queue is full and
					// this goroutine is the only consumer visible here —
					// confirm this cannot stall the routine.
					s.qPush(pkt)
					continue
				}
				t.SetPacketID(id)
				m = t
			case *unacknowledged:
				m = t.Provider
			default:
				s.log.Warn("Unexpected packet type to transmit. drop it", zap.String("Packet", fmt.Sprintf("%#v", pkt)))
				continue
			}

			if m != nil {
				// Track the packet in the queue of not yet acknowledged messages.
				s.pubOut.store(m)
			}
		// Messages with QoS == 0.
		case pkt := <-s.txGMessages:
			// pkt already is a packet.Provider; a type assertion here would
			// only add a panic for a nil value.
			m = pkt
		}

		// Send the message to the client. On error stop transmitting and enter
		// the shutdown path: QoS > 0 messages are sent again later from the
		// unacknowledged queue, QoS 0 messages are dropped.
		if m != nil {
			if err = trySend(m); err != nil {
				s.log.Warn("Message transmit error.",
					zap.String("ClientID", s.ID), zap.String("Packet", fmt.Sprintf("%#v", m)), zap.Error(err),
				)
				return
			}
		}
	}
}

// setTopicAlias attaches a topic alias to pkt when the client accepts topic
// aliases (MaxTxTopicAlias > 0).
//
//   - A topic that already has an alias recorded in txTopicAlias gets that
//     alias and its topic name is cleared, so only the alias is transmitted.
//   - A topic without an alias gets the next unused alias or, once all aliases
//     are in use, a randomly chosen one taken over from its previous topic.
//     The topic name is kept in the packet so the receiver can (re)bind the
//     alias, and the mapping is recorded for later packets.
//
// Packets are left untouched when the alias property cannot be set.
func (s *Type) setTopicAlias(pkt *packet.Publish) {
	if s.MaxTxTopicAlias == 0 {
		return
	}

	topic := pkt.Topic()
	if topic == "" {
		// Nothing to alias. NOTE(review): presumably a packet whose topic was
		// already replaced by an alias on an earlier attempt — confirm.
		return
	}

	if alias, known := s.txTopicAlias[topic]; known {
		if err := pkt.PropertySet(packet.PropertyTopicAlias, alias); err == nil {
			pkt.SetTopic("") // nolint: errcheck
		}
		return
	}

	fresh := s.topicAliasCurrMax < s.MaxTxTopicAlias
	var alias uint16
	if fresh {
		alias = s.topicAliasCurrMax + 1
	} else {
		alias = uint16(rand.Intn(int(s.MaxTxTopicAlias)) + 1)
	}

	if err := pkt.PropertySet(packet.PropertyTopicAlias, alias); err != nil {
		return
	}

	if fresh {
		s.topicAliasCurrMax = alias
	} else {
		// The alias is being rebound: forget the topic that owned it, otherwise
		// that topic would later be sent with an alias that no longer means it.
		for owner, used := range s.txTopicAlias {
			if used == alias {
				delete(s.txTopicAlias, owner)
				break
			}
		}
	}

	if s.txTopicAlias == nil {
		s.txTopicAlias = make(map[string]uint16)
	}
	s.txTopicAlias[topic] = alias
}
